Fine-tune the Hugging Face Vision Transformer (ViT) with PyTorch for vehicle-type image classification, comparing two strategies: gradually unfreezing layers starting from the end of the network versus training a model with all layers unfrozen from the start.
We download the image dataset from Roboflow: https://universe.roboflow.com/paul-guerrie-tang1/vehicle-classification-eapcd
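The dataset is saved as `Vehicle-Classification-1/{train,valid,test}/<class_name>/<image files>`: one folder per split, each containing one subfolder per vehicle class (the layout `torchvision.datasets.ImageFolder` expects).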
# Downloading the dataset
from roboflow import Roboflow
import os
# Download the Roboflow "folder" export only when it is not already on disk.
# NOTE(review): assumes the export is written to ./Vehicle-Classification-1/
# (project name + version number) -- confirm if the project or version changes.
if not os.path.exists("./Vehicle-Classification-1/"):
    # Never hard-code credentials in a notebook/script: read the key from the
    # environment. NOTE(review): the key previously committed here is exposed
    # and should be revoked in the Roboflow dashboard.
    api_key = os.environ.get("ROBOFLOW_API_KEY")
    if not api_key:
        raise RuntimeError(
            "Dataset not found locally; set the ROBOFLOW_API_KEY environment "
            "variable to download it from Roboflow."
        )
    rf = Roboflow(api_key=api_key)
    project = rf.workspace("paul-guerrie-tang1").project("vehicle-classification-eapcd")
    dataset = project.version(1).download("folder")
import torch
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader
from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTImageProcessor
import tqdm as notebook_tqdm
import pytorch_lightning as pl
from torchmetrics import Accuracy
import numpy as np
import warnings
# Silence every warning category for the rest of the session (installs an
# "ignore" filter for the base Warning class at the front of the filter list).
warnings.simplefilter('ignore')
# Assigning device based on windows or macOS
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
print("Running on", device)
Running on cuda
# Root of the downloaded export; every split lives in a subdirectory of it.
data_dir = Path('Vehicle-Classification-1')
train_data_dir = data_dir / 'train'
val_data_dir = data_dir / 'valid'
test_data_dir = data_dir / 'test'

# NOTE(review): ImageFolder uses each immediate subdirectory of its root as a
# class, so `ds` is presumably labelled by split (test/train/valid) rather than
# by vehicle type -- confirm this dataset is actually needed.
ds = ImageFolder(data_dir)

# One ImageFolder per split; inside each split the class subfolder names become
# the labels. Built in train -> valid -> test order.
train_ds, val_ds, test_ds = (
    ImageFolder(split_dir)
    for split_dir in (train_data_dir, val_data_dir, test_data_dir)
)
import os
test_path = './Vehicle-Classification-1/test/'
plt.figure(figsize=(60, 50))

# Show one sample image per class from the test split.
# Keep only real class directories, sorted for a deterministic order:
# os.listdir() order is unspecified, so slicing off "the last entry" dropped an
# unpredictable item (possibly a real class) instead of a stray file.
classes = sorted(
    entry for entry in os.listdir(test_path)
    if os.path.isdir(os.path.join(test_path, entry))
)

# Size the grid from the classes actually plotted: 4 rows and enough columns to
# hold every class (ceiling division). Flooring the column count could make the
# grid too small, or give it 0 columns, and subplot() would then raise.
n_rows = 4
n_cols = max(1, (len(classes) + n_rows - 1) // n_rows)

for i, class_folder in enumerate(classes):
    class_dir = os.path.join(test_path, class_folder)
    # Ignore hidden files (e.g. .DS_Store) and nested directories; take the
    # alphabetically first image so reruns show the same sample.
    image_names = sorted(
        name for name in os.listdir(class_dir)
        if not name.startswith('.') and os.path.isfile(os.path.join(class_dir, name))
    )
    if not image_names:
        continue  # empty class folder: leave its grid cell blank
    image_name = image_names[0]

    ax = plt.subplot(n_rows, n_cols, i + 1)
    ax.set_title(
        class_folder,
        size='xx-large',
        pad=2,
        loc='left',
        y=0,
        backgroundcolor='white'
    )
    ax.axis('off')
    image = Image.open(os.path.join(class_dir, image_name))
    ax.imshow(image)